{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JCkPHee-alKM"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "cvrknnPsaqaM"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2Zt0JEl38SWX"
      },
      "source": [
        "# Magnitude-based weight pruning with Keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "IwBSTsryazfT"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "L7IGLmaw5DjA"
      },
      "source": [
        "## Overview\n",
        "\n",
        "Welcome to the tutorial for *weight pruning*, part of the TensorFlow Model Optimization toolkit.\n",
        "\n",
        "#### What is weight pruning?\n",
        "\n",
        "Weight pruning means exactly what it says: eliminating unnecessary values from the weight tensors. In practice, we set selected neural network parameters to zero, which removes low-weight connections between the *layers* of a neural network.\n",
        "\n",
        "#### Why is it useful?\n",
        "\n",
        "Tensors in which many values are set to zero are called *sparse*. Sparsity brings two important benefits:\n",
        "* *Compression*. Sparse tensors are amenable to compression by only keeping the non-zero values and their corresponding coordinates.\n",
        "* *Speed*. Sparse tensors allow us to skip otherwise unnecessary computations involving the zero values.\n",
        "\n",
        "#### How does it work?\n",
        "\n",
        "Our Keras-based weight pruning API is designed to iteratively remove connections based on their magnitude during training. For more details on how to use the API, please refer to the [TensorFlow model optimization repo](https://github.com/tensorflow/model-optimization) on GitHub.\n",
        "\n",
        "In this tutorial, we'll walk you through an end-to-end example of using the weight pruning API on a simple MNIST model. We will show that simply applying a generic file compression algorithm (e.g., zip) reduces the size of the pruned Keras model, and that this size reduction persists when the model is converted to the TensorFlow Lite format.\n",
        "\n",
        "Two things worth clarifying:\n",
        "* The technique and API are not specific to TensorFlow Lite. We just show their application on the TensorFlow Lite backend because it covers size-sensitive use cases.\n",
        "* By itself, a sparse model will not execute faster. Sparsity only enables speedups on backends that can exploit it. In the near future, however, TensorFlow Lite will take advantage of sparsity to speed up computations.\n",
        "\n",
        "To recap, in this tutorial we will:\n",
        "1.   Train an MNIST model with Keras from scratch.\n",
        "2.   Train a pruned MNIST model with the pruning API.\n",
        "3.   Compare the size of the pruned model and the non-pruned one after compression.\n",
        "4.   Convert the pruned model to TensorFlow Lite format and verify that accuracy persists.\n",
        "5.   Show how the pruned model works with other optimization techniques, like post-training quantization."
      ]
    },
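    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The idea behind magnitude-based pruning can be illustrated without the pruning API. The next cell is a minimal NumPy sketch, not the toolkit's implementation. The hypothetical helper `magnitude_prune` zeroes out the 90% of entries with the smallest magnitude in a random 256x256 tensor in a single shot, whereas the real API prunes gradually during training. The cell then uses `zlib` to compare how well the dense and the sparse tensors compress. `zlib` implements DEFLATE, the algorithm behind `zipfile.ZIP_DEFLATED`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import zlib\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical helper for illustration only (not part of the pruning API):\n",
        "# zero out the `sparsity` fraction of entries with the smallest magnitude.\n",
        "def magnitude_prune(weights, sparsity):\n",
        "  k = int(round(sparsity * weights.size))  # number of entries to remove\n",
        "  if k == 0:\n",
        "    return weights.copy()\n",
        "  magnitudes = np.abs(weights)\n",
        "  # The k-th smallest magnitude; entries at or below it are pruned.\n",
        "  threshold = np.partition(magnitudes.ravel(), k - 1)[k - 1]\n",
        "  return np.where(magnitudes > threshold, weights, 0.0).astype(weights.dtype)\n",
        "\n",
        "rng = np.random.RandomState(0)\n",
        "dense = rng.randn(256, 256).astype(np.float32)\n",
        "pruned = magnitude_prune(dense, sparsity=0.9)\n",
        "print('Fraction of zeros: %.3f' % np.mean(pruned == 0))\n",
        "\n",
        "# DEFLATE (zlib) can encode the many repeated zero bytes very compactly.\n",
        "dense_size = len(zlib.compress(dense.tobytes()))\n",
        "pruned_size = len(zlib.compress(pruned.tobytes()))\n",
        "print('Compressed size, dense:  %d bytes' % dense_size)\n",
        "print('Compressed size, pruned: %d bytes' % pruned_size)"
      ]
    },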
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "P8qFbkru8FKu"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "q8zIQsT9mUTw"
      },
      "source": [
        "To use the pruning API, install the `tensorflow-model-optimization` package. See the [TensorFlow model optimization repo](https://github.com/tensorflow/model-optimization) for compatible API versions.\n",
        "\n",
        "Since you will train a few models in this tutorial, install the `tensorflow-gpu` package to speed things up. Enable the GPU with *Runtime \u003e Change runtime type \u003e Hardware accelerator* and make sure GPU is selected."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Pn836LSTNSHA"
      },
      "outputs": [],
      "source": [
        "! pip uninstall -y tensorflow\n",
        "! pip uninstall -y tf-nightly\n",
        "! pip install -U tensorflow-gpu==1.14.0\n",
        "\n",
        "! pip install tensorflow-model-optimization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "1ykjgo4UNXmD"
      },
      "outputs": [],
      "source": [
        "%load_ext tensorboard\n",
        "import tensorboard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "mydXeQlDNbnR"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "tf.enable_eager_execution()\n",
        "\n",
        "import tempfile\n",
        "import zipfile\n",
        "import os"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "gBYltugp-MdR"
      },
      "source": [
        "## Prepare the training data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "P_MJqxz5z2dh"
      },
      "outputs": [],
      "source": [
        "batch_size = 128\n",
        "num_classes = 10\n",
        "epochs = 10\n",
        "\n",
        "# input image dimensions\n",
        "img_rows, img_cols = 28, 28\n",
        "\n",
        "# the data, shuffled and split between train and test sets\n",
        "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
        "\n",
        "if tf.keras.backend.image_data_format() == 'channels_first':\n",
        "  x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)\n",
        "  x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)\n",
        "  input_shape = (1, img_rows, img_cols)\n",
        "else:\n",
        "  x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n",
        "  x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n",
        "  input_shape = (img_rows, img_cols, 1)\n",
        "\n",
        "x_train = x_train.astype('float32')\n",
        "x_test = x_test.astype('float32')\n",
        "x_train /= 255\n",
        "x_test /= 255\n",
        "print('x_train shape:', x_train.shape)\n",
        "print(x_train.shape[0], 'train samples')\n",
        "print(x_test.shape[0], 'test samples')\n",
        "\n",
        "# convert class vectors to binary class matrices\n",
        "y_train = tf.keras.utils.to_categorical(y_train, num_classes)\n",
        "y_test = tf.keras.utils.to_categorical(y_test, num_classes)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OmdnYKPK--5L"
      },
      "source": [
        "## Train an MNIST model without pruning"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "r1p_y3gPWeW2"
      },
      "source": [
        "### Build the MNIST model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "uLg0SGdYp2Q6"
      },
      "outputs": [],
      "source": [
        "l = tf.keras.layers\n",
        "\n",
        "model = tf.keras.Sequential([\n",
        "    l.Conv2D(\n",
        "        32, 5, padding='same', activation='relu', input_shape=input_shape),\n",
        "    l.MaxPooling2D((2, 2), (2, 2), padding='same'),\n",
        "    l.BatchNormalization(),\n",
        "    l.Conv2D(64, 5, padding='same', activation='relu'),\n",
        "    l.MaxPooling2D((2, 2), (2, 2), padding='same'),\n",
        "    l.Flatten(),\n",
        "    l.Dense(1024, activation='relu'),\n",
        "    l.Dropout(0.4),\n",
        "    l.Dense(num_classes, activation='softmax')\n",
        "])\n",
        "\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5dHynoso_xXF"
      },
      "source": [
        "### Train the model to reach an accuracy \u003e99%"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "KGo0UICmA8J3"
      },
      "source": [
        "Load [TensorBoard](https://www.tensorflow.org/tensorboard) to monitor the training process."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "WJyS_IcQBBqg"
      },
      "outputs": [],
      "source": [
        "logdir = tempfile.mkdtemp()\n",
        "print('Writing training logs to ' + logdir)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "QfgjuOLj_mFu"
      },
      "outputs": [],
      "source": [
        "%tensorboard --logdir={logdir}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "A7c-fsQpTzOr"
      },
      "outputs": [],
      "source": [
        "callbacks = [tf.keras.callbacks.TensorBoard(log_dir=logdir, profile_batch=0)]\n",
        "\n",
        "model.compile(\n",
        "    loss=tf.keras.losses.categorical_crossentropy,\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])\n",
        "\n",
        "model.fit(x_train, y_train,\n",
        "          batch_size=batch_size,\n",
        "          epochs=epochs,\n",
        "          verbose=1,\n",
        "          callbacks=callbacks,\n",
        "          validation_data=(x_test, y_test))\n",
        "score = model.evaluate(x_test, y_test, verbose=0)\n",
        "print('Test loss:', score[0])\n",
        "print('Test accuracy:', score[1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "eLgMavDVCde5"
      },
      "source": [
        "### Save the original model for size comparison later"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "8E6U7GUIx5r9"
      },
      "outputs": [],
      "source": [
        "# Backend agnostic way to save/restore models\n",
        "_, keras_file = tempfile.mkstemp('.h5')\n",
        "print('Saving model to: ', keras_file)\n",
        "tf.keras.models.save_model(model, keras_file, include_optimizer=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "gWiQbobKC0NZ"
      },
      "source": [
        "## Train a pruned MNIST model\n",
        "\n",
        "We provide a `prune_low_magnitude()` API to train models with removed connections. The Keras-based API can be applied to individual layers or to the entire model. We will show you how to use both in the following sections.\n",
        "\n",
        "At a high level, the technique works by iteratively removing (i.e., zeroing out) connections between layers, following a schedule toward a target sparsity.\n",
        "\n",
        "For example, a typical configuration will target 75% sparsity by pruning connections every 100 steps (training batches, not epochs), starting from step 2,000. For more details on the possible configurations, please refer to the GitHub documentation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FfJ6Tm3KXMFY"
      },
      "source": [
        "### Build a pruned model layer by layer\n",
        "In this example, we show how to use the API at the level of layers, and build a pruned MNIST solver model.\n",
        "\n",
        "In this case, `prune_low_magnitude()` receives as a parameter the Keras layer whose weights we want pruned.\n",
        "\n",
        "This function also requires pruning parameters (`pruning_params` below) that configure the pruning algorithm during training. Please refer to our GitHub page for detailed documentation. The parameters used here mean:\n",
        "\n",
        "1.   **Sparsity.** `PolynomialDecay` is used across the whole training process. We start at a sparsity level of 50% and gradually train the model to reach 90% sparsity. X% sparsity means that X% of the entries in the weight tensor are pruned away (set to zero).\n",
        "2.   **Schedule.** Connections are pruned from step 2000 until the end of training, with pruning applied every 100 steps. The reasoning is that we want to train the model without pruning for a few epochs so that it first reaches a reasonable accuracy, which aids convergence. Furthermore, we give the model some time to recover after each pruning step, so pruning does not happen on every step. We set the pruning frequency to 100.\n",
        "\n"
      ]
    },
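    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the schedule concrete, the next cell sketches how the target sparsity evolves with these settings. It is a hypothetical re-implementation for illustration, not the library's code. It makes two assumptions: `PolynomialDecay` moves along a cubic curve (exponent 3) from `initial_sparsity` at `begin_step` to `final_sparsity` at `end_step`, and the pruning masks are only updated every `frequency` steps. The value `end_step=5628` assumes 12 epochs over the 60,000 MNIST training examples at batch size 128, i.e. 469 steps per epoch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Hypothetical re-implementation of a polynomial sparsity schedule, for\n",
        "# illustration only. It assumes a cubic curve (power=3); see the\n",
        "# PolynomialDecay documentation for the library's exact behavior.\n",
        "def target_sparsity(step, initial_sparsity=0.50, final_sparsity=0.90,\n",
        "                    begin_step=2000, end_step=5628, power=3):\n",
        "  # Fraction of the pruning window that has elapsed, clipped to [0, 1].\n",
        "  progress = (step - begin_step) / float(end_step - begin_step)\n",
        "  progress = min(1.0, max(0.0, progress))\n",
        "  return (final_sparsity +\n",
        "          (initial_sparsity - final_sparsity) * (1.0 - progress) ** power)\n",
        "\n",
        "def prunes_at(step, begin_step=2000, end_step=5628, frequency=100):\n",
        "  # Masks are only updated every `frequency` steps inside the pruning window.\n",
        "  return begin_step <= step <= end_step and (step - begin_step) % frequency == 0\n",
        "\n",
        "for step in [2000, 2500, 3000, 4000, 5000, 5600]:\n",
        "  print('step %4d  prune now: %-5s  target sparsity: %.3f'\n",
        "        % (step, prunes_at(step), target_sparsity(step)))"
      ]
    },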
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Ip8qCtSlU3TQ"
      },
      "outputs": [],
      "source": [
        "from tensorflow_model_optimization.sparsity import keras as sparsity"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tA9rnRUrebTE"
      },
      "source": [
        "To demonstrate how to save and restore a pruned Keras model, the following example first trains the model for 10 epochs and saves it to disk. It then restores the model and continues training for 2 more epochs. With gradual sparsity, the four important parameters are `initial_sparsity`, `final_sparsity`, `begin_step`, and `end_step`. The first three are straightforward. Let's calculate the end step from the number of training examples, the batch size, and the total number of epochs to train (12)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Rrs-xoB2cSSQ"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "epochs = 12\n",
        "num_train_samples = x_train.shape[0]\n",
        "end_step = np.ceil(1.0 * num_train_samples / batch_size).astype(np.int32) * epochs\n",
        "print('End step: ' + str(end_step))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Shz-r2RiqFca"
      },
      "outputs": [],
      "source": [
        "pruning_params = {\n",
        "      'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.50,\n",
        "                                                   final_sparsity=0.90,\n",
        "                                                   begin_step=2000,\n",
        "                                                   end_step=end_step,\n",
        "                                                   frequency=100)\n",
        "}\n",
        "\n",
        "pruned_model = tf.keras.Sequential([\n",
        "    sparsity.prune_low_magnitude(\n",
        "        l.Conv2D(32, 5, padding='same', activation='relu'),\n",
        "        input_shape=input_shape,\n",
        "        **pruning_params),\n",
        "    l.MaxPooling2D((2, 2), (2, 2), padding='same'),\n",
        "    l.BatchNormalization(),\n",
        "    sparsity.prune_low_magnitude(\n",
        "        l.Conv2D(64, 5, padding='same', activation='relu'), **pruning_params),\n",
        "    l.MaxPooling2D((2, 2), (2, 2), padding='same'),\n",
        "    l.Flatten(),\n",
        "    sparsity.prune_low_magnitude(l.Dense(1024, activation='relu'),\n",
        "                                 **pruning_params),\n",
        "    l.Dropout(0.4),\n",
        "    sparsity.prune_low_magnitude(l.Dense(num_classes, activation='softmax'),\n",
        "                                 **pruning_params)\n",
        "])\n",
        "\n",
        "pruned_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YczppQ6vEPJg"
      },
      "source": [
        "Load TensorBoard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "eIO8UpZyEYkp"
      },
      "outputs": [],
      "source": [
        "logdir = tempfile.mkdtemp()\n",
        "print('Writing training logs to ' + logdir)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "KKb8XpDkA8TN"
      },
      "outputs": [],
      "source": [
        "%tensorboard --logdir={logdir}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "z2166laKE_N6"
      },
      "source": [
        "### Train the model\n",
        "\n",
        "Pruning starts at step 2000, by which point the model has typically reached an accuracy above 98%."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "GGoOTwQRzEP4"
      },
      "outputs": [],
      "source": [
        "pruned_model.compile(\n",
        "    loss=tf.keras.losses.categorical_crossentropy,\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])\n",
        "\n",
        "# Add a pruning step callback to peg the pruning step to the optimizer's\n",
        "# step. Also add a callback to add pruning summaries to tensorboard\n",
        "callbacks = [\n",
        "    sparsity.UpdatePruningStep(),\n",
        "    sparsity.PruningSummaries(log_dir=logdir, profile_batch=0)\n",
        "]\n",
        "\n",
        "pruned_model.fit(x_train, y_train,\n",
        "                 batch_size=batch_size,\n",
        "                 epochs=10,\n",
        "                 verbose=1,\n",
        "                 callbacks=callbacks,\n",
        "                 validation_data=(x_test, y_test))\n",
        "\n",
        "score = pruned_model.evaluate(x_test, y_test, verbose=0)\n",
        "print('Test loss:', score[0])\n",
        "print('Test accuracy:', score[1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8Nzrm5pN1viP"
      },
      "source": [
        "### Save and restore the pruned model\n",
        "\n",
        "Save the pruned model, restore it, and continue training for two more epochs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Rd8G7srV2dcI"
      },
      "outputs": [],
      "source": [
        "_, checkpoint_file = tempfile.mkstemp('.h5')\n",
        "print('Saving pruned model to: ', checkpoint_file)\n",
        "# save_model() sets include_optimizer=True by default; it is spelled out here\n",
        "# for emphasis.\n",
        "tf.keras.models.save_model(pruned_model, checkpoint_file, include_optimizer=True)\n",
        "\n",
        "with sparsity.prune_scope():\n",
        "  restored_model = tf.keras.models.load_model(checkpoint_file)\n",
        "\n",
        "callbacks = [\n",
        "    sparsity.UpdatePruningStep(),\n",
        "    sparsity.PruningSummaries(log_dir=logdir, profile_batch=0)\n",
        "]\n",
        "\n", 
        "restored_model.fit(x_train, y_train,\n",
        "                   batch_size=batch_size,\n",
        "                   epochs=2,\n",
        "                   verbose=1,\n",
        "                   callbacks=callbacks,\n",
        "                   validation_data=(x_test, y_test))\n",
        "\n",
        "score = restored_model.evaluate(x_test, y_test, verbose=0)\n",
        "print('Test loss:', score[0])\n",
        "print('Test accuracy:', score[1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1vV7pZrW5TQh"
      },
      "source": [
        "In the example above, a few things to note are:\n",
        "\n",
        "\n",
        "*   When saving the model, `include_optimizer` must be set to `True`. The state of the optimizer needs to be preserved across training sessions for pruning to work properly.\n",
        "*   When loading the pruned model, you need `prune_scope()` for deserialization.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tMTFhyc0vAA3"
      },
      "source": [
        "### Strip the pruning wrappers from the pruned model before export for serving\n",
        "Before exporting a model for serving, call the `strip_pruning` API to remove the pruning wrappers from the model, since they are only needed during training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "jyCjjpUjvImz"
      },
      "outputs": [],
      "source": [
        "final_model = sparsity.strip_pruning(pruned_model)\n",
        "final_model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "B63tViKp_qLK"
      },
      "outputs": [],
      "source": [
        "_, pruned_keras_file = tempfile.mkstemp('.h5')\n",
        "print('Saving pruned model to: ', pruned_keras_file)\n",
        "\n",
        "# No need to save the optimizer with the graph for serving.\n",
        "tf.keras.models.save_model(final_model, pruned_keras_file, include_optimizer=False)"
      ]
    },
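    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "After stripping, the pruned weights remain exactly zero, so it is easy to check how sparse each tensor actually is. The next cell is a minimal sketch built around a hypothetical helper, `print_sparsity` (not part of the toolkit). It reports the fraction of exactly-zero entries for a list of `(name, array)` pairs and is demonstrated on toy arrays so that it runs on its own."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical helper (not part of tensorflow_model_optimization): report the\n",
        "# fraction of exactly-zero entries in each named weight array.\n",
        "def print_sparsity(named_weights):\n",
        "  report = {}\n",
        "  for name, values in named_weights:\n",
        "    report[name] = float(np.mean(np.asarray(values) == 0))\n",
        "    print('%-40s sparsity: %.2f' % (name, report[name]))\n",
        "  return report\n",
        "\n",
        "# Toy demo: a kernel with 3 of 4 entries zeroed and a fully dense bias.\n",
        "demo_report = print_sparsity([('toy/kernel', np.array([0.0, 0.0, 0.0, 1.5])),\n",
        "                              ('toy/bias', np.array([0.3, -0.2]))])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To inspect the stripped model, you could call `print_sparsity(zip([w.name for w in final_model.weights], final_model.get_weights()))`. We would expect only the kernels of the wrapped `Conv2D` and `Dense` layers to show high sparsity, close to (or slightly below, if training stopped before `end_step`) the 90% target. Biases and `BatchNormalization` parameters should stay dense."
      ]
    },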
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GhXHuBAGOBvY"
      },
      "source": [
        "### Compare the size of the unpruned vs. pruned model after compression"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Hk4DoZTIy2uU"
      },
      "outputs": [],
      "source": [
        "_, zip1 = tempfile.mkstemp('.zip')\n",
        "with zipfile.ZipFile(zip1, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n",
        "  f.write(keras_file)\n",
        "print(\"Size of the unpruned model before compression: %.2f MB\" %\n",
        "      (os.path.getsize(keras_file) / float(2**20)))\n",
        "print(\"Size of the unpruned model after compression: %.2f MB\" %\n",
        "      (os.path.getsize(zip1) / float(2**20)))\n",
        "\n",
        "_, zip2 = tempfile.mkstemp('.zip')\n",
        "with zipfile.ZipFile(zip2, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n",
        "  f.write(pruned_keras_file)\n",
        "print(\"Size of the pruned model before compression: %.2f MB\" %\n",
        "      (os.path.getsize(pruned_keras_file) / float(2**20)))\n",
        "print(\"Size of the pruned model after compression: %.2f MB\" %\n",
        "      (os.path.getsize(zip2) / float(2**20)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dayb_w_GqWs_"
      },
      "source": [
        "### Prune a whole model\n",
        "\n",
        "The `prune_low_magnitude` function can also be applied to the entire Keras model. \n",
        "\n",
        "In this case, the algorithm will be applied to all layers that are amenable to weight pruning (that the API knows about). Layers that the API knows are not amenable to weight pruning will be ignored, and layers unknown to the API will cause an error.\n",
        "\n",
        "*If your model has layers whose weights the API does not know how to prune, but that are perfectly fine to leave \"un-pruned\", then just apply the API on a per-layer basis.*\n",
        "\n",
        "Regarding pruning configuration, the same settings apply to all prunable layers in the model.\n",
        "\n",
        "Also noteworthy is that pruning doesn't preserve the optimizer associated with the original model. As a result, it is necessary to recompile the pruned model with a new optimizer."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "I-W7Sj8fZjeb"
      },
      "source": [
        "Before we move forward with the example, let's address the common use case where you may already have a serialized pre-trained Keras model to which you would like to apply weight pruning. We will take the original MNIST model trained previously to show how this works. In this case, you start by loading the model into memory like this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "qJm1SJxfqy2e"
      },
      "outputs": [],
      "source": [
        "# Load the serialized model\n",
        "loaded_model = tf.keras.models.load_model(keras_file)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "etYrWnrSMpB5"
      },
      "source": [
        "Then you can prune the loaded model and compile the pruned model for training. In this case, training will restart from step 0. Given that the model we loaded has already reached a satisfactory accuracy, we can start pruning immediately. As a result, we set `begin_step` to 0 here and only train for another four epochs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "CabnOldzrSN2"
      },
      "outputs": [],
      "source": [
        "epochs = 4\n",
        "end_step = np.ceil(1.0 * num_train_samples / batch_size).astype(np.int32) * epochs\n",
        "print(end_step)\n",
        "\n",
        "new_pruning_params = {\n",
        "      'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.50,\n",
        "                                                   final_sparsity=0.90,\n",
        "                                                   begin_step=0,\n",
        "                                                   end_step=end_step,\n",
        "                                                   frequency=100)\n",
        "}\n",
        "\n",
        "new_pruned_model = sparsity.prune_low_magnitude(loaded_model, **new_pruning_params)\n",
        "new_pruned_model.summary()\n",
        "\n",
        "new_pruned_model.compile(\n",
        "    loss=tf.keras.losses.categorical_crossentropy,\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9qCipfCnaY7g"
      },
      "source": [
        "Load TensorBoard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "VPBlcTXWB9tx"
      },
      "outputs": [],
      "source": [
        "logdir = tempfile.mkdtemp()\n",
        "print('Writing training logs to ' + logdir)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "X43ix4NZCDNS"
      },
      "outputs": [],
      "source": [
        "%tensorboard --logdir={logdir}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "r2hLPO7KNKq_"
      },
      "source": [
        "### Train the model for another four epochs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "36hymxokrnbw"
      },
      "outputs": [],
      "source": [
        "# Add a pruning step callback to peg the pruning step to the optimizer's\n",
        "# step. Also add a callback to add pruning summaries to tensorboard\n",
        "callbacks = [\n",
        "    sparsity.UpdatePruningStep(),\n",
        "    sparsity.PruningSummaries(log_dir=logdir, profile_batch=0)\n",
        "]\n",
        "\n",
        "new_pruned_model.fit(x_train, y_train,\n",
        "                     batch_size=batch_size,\n",
        "                     epochs=epochs,\n",
        "                     verbose=1,\n",
        "                     callbacks=callbacks,\n",
        "                     validation_data=(x_test, y_test))\n",
        "\n",
        "score = new_pruned_model.evaluate(x_test, y_test, verbose=0)\n",
        "print('Test loss:', score[0])\n",
        "print('Test accuracy:', score[1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Z0yphbuRarfT"
      },
      "source": [
        "### Export the pruned model for serving"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "anFHUpCXrxMe"
      },
      "outputs": [],
      "source": [
        "final_model = sparsity.strip_pruning(new_pruned_model)\n",
        "final_model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "4CmEtxHEso7g"
      },
      "outputs": [],
      "source": [
        "_, new_pruned_keras_file = tempfile.mkstemp('.h5')\n",
        "print('Saving pruned model to: ', new_pruned_keras_file)\n",
        "tf.keras.models.save_model(final_model, new_pruned_keras_file,\n",
        "                           include_optimizer=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YT8YqXAza3Tt"
      },
      "source": [
        "After compression, the size of this model is about the same as that of the model pruned layer by layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "AtKANC0hs2RR"
      },
      "outputs": [],
      "source": [
        "_, zip3 = tempfile.mkstemp('.zip')\n",
        "with zipfile.ZipFile(zip3, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n",
        "  f.write(new_pruned_keras_file)\n",
        "print(\"Size of the pruned model before compression: %.2f MB\"\n",
        "      % (os.path.getsize(new_pruned_keras_file) / float(2**20)))\n",
        "print(\"Size of the pruned model after compression: %.2f MB\"\n",
        "      % (os.path.getsize(zip3) / float(2**20)))"
      ]
    },
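    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Pruned models compress well because magnitude-based pruning sets a large fraction of the weights to exactly zero. Long runs of identical bytes compress well with DEFLATE, the algorithm used by `zipfile.ZIP_DEFLATED`. The cell below is a minimal sketch of this effect on a synthetic float32 weight matrix rather than on the model itself. The normally distributed weights, the 90% sparsity target, and the use of `zlib.compress` (also DEFLATE) in place of a zip file are assumptions made for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import zlib\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "# Synthetic float32 weights standing in for one dense layer's kernel.\n",
        "dense = rng.normal(0.0, 0.05, size=(512, 512)).astype(np.float32)\n",
        "\n",
        "# Magnitude-based pruning: zero out the 90% of weights with the smallest |w|.\n",
        "target_sparsity = 0.9\n",
        "threshold = np.quantile(np.abs(dense), target_sparsity)\n",
        "pruned = np.where(np.abs(dense) > threshold, dense, 0.0).astype(np.float32)\n",
        "\n",
        "def compressed_mb(arr):\n",
        "  \"\"\"Size in MB of the array's raw bytes after DEFLATE compression.\"\"\"\n",
        "  return len(zlib.compress(arr.tobytes())) / float(2**20)\n",
        "\n",
        "print(\"Sparsity of pruned weights: %.2f\" % np.mean(pruned == 0))\n",
        "print(\"Raw size of each matrix:    %.2f MB\" % (dense.nbytes / float(2**20)))\n",
        "print(\"Dense, compressed:          %.2f MB\" % compressed_mb(dense))\n",
        "print(\"Pruned, compressed:         %.2f MB\" % compressed_mb(pruned))"
      ]
    },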
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zXrLGUPfIwvV"
      },
      "source": [
        "## Convert to TensorFlow Lite\n",
        "\n",
        "Finally, you can convert the pruned model to a format that runs on your target backend. TensorFlow Lite is one such format, which you can use to deploy models to mobile devices. To convert to a TensorFlow Lite graph, use the `TFLiteConverter` as shown below:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1f9Eb2K0bcJG"
      },
      "source": [
        "### Convert the model with TFLiteConverter"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Ctqfiix-H-x7"
      },
      "outputs": [],
      "source": [
        "tflite_model_file = '/tmp/sparse_mnist.tflite'\n",
        "converter = tf.lite.TFLiteConverter.from_keras_model_file(pruned_keras_file)\n",
        "tflite_model = converter.convert()\n",
        "with open(tflite_model_file, 'wb') as f:\n",
        "  f.write(tflite_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "fdIiNnrPANcw"
      },
      "source": [
        "### Size of the TensorFlow Lite model after compression"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "iYMSXAU1AUYI"
      },
      "outputs": [],
      "source": [
        "_, zip_tflite = tempfile.mkstemp('.zip')\n",
        "with zipfile.ZipFile(zip_tflite, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n",
        "  f.write(tflite_model_file)\n",
        "print(\"Size of the tflite model before compression: %.2f MB\"\n",
        "      % (os.path.getsize(tflite_model_file) / float(2**20)))\n",
        "print(\"Size of the tflite model after compression: %.2f MB\"\n",
        "      % (os.path.getsize(zip_tflite) / float(2**20)))"
      ]
    },
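    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The zip-and-measure code is repeated in several cells of this notebook, so you may prefer to factor it into a helper. The cell below is a minimal sketch that uses only the standard library. `get_zipped_size_mb` is a hypothetical name. Unlike the cells above, it closes the file descriptor returned by `tempfile.mkstemp` and deletes the temporary zip file afterwards. The demo uses a throwaway file of zero bytes; with this notebook's files you would pass, for example, `tflite_model_file`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import tempfile\n",
        "import zipfile\n",
        "\n",
        "def get_zipped_size_mb(path):\n",
        "  \"\"\"Returns (raw_mb, zipped_mb) for the file at `path`.\n",
        "\n",
        "  Hypothetical helper mirroring the zip-and-measure cells in this notebook.\n",
        "  \"\"\"\n",
        "  fd, zip_path = tempfile.mkstemp('.zip')\n",
        "  os.close(fd)  # mkstemp returns an open descriptor; ZipFile reopens by path.\n",
        "  try:\n",
        "    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n",
        "      f.write(path)\n",
        "    return (os.path.getsize(path) / float(2**20),\n",
        "            os.path.getsize(zip_path) / float(2**20))\n",
        "  finally:\n",
        "    os.remove(zip_path)\n",
        "\n",
        "# Demo on a throwaway file of zero bytes, which compresses extremely well.\n",
        "fd, demo_path = tempfile.mkstemp('.bin')\n",
        "with os.fdopen(fd, 'wb') as fh:\n",
        "  fh.write(b'\\x00' * 100000)\n",
        "demo_raw_mb, demo_zipped_mb = get_zipped_size_mb(demo_path)\n",
        "os.remove(demo_path)\n",
        "print(\"Before compression: %.3f MB, after: %.3f MB\" % (demo_raw_mb, demo_zipped_mb))\n",
        "\n",
        "# With this notebook's files, for example:\n",
        "# print(get_zipped_size_mb(tflite_model_file))"
      ]
    },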
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vBqUX1qopV1k"
      },
      "source": [
        "### Evaluate the accuracy of the TensorFlow Lite model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "F5AY-TuivmbP"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "def eval_model(interpreter, x_test, y_test):\n",
        "  # Look up the tensor indices on the interpreter that was passed in, so the\n",
        "  # function works for any converted model, not just a global one.\n",
        "  input_index = interpreter.get_input_details()[0][\"index\"]\n",
        "  output_index = interpreter.get_output_details()[0][\"index\"]\n",
        "  total_seen = 0\n",
        "  num_correct = 0\n",
        "\n",
        "  for img, label in zip(x_test, y_test):\n",
        "    # Run one image at a time as a float32 batch of size 1.\n",
        "    inp = img.reshape((1, 28, 28, 1)).astype(np.float32)\n",
        "    total_seen += 1\n",
        "    interpreter.set_tensor(input_index, inp)\n",
        "    interpreter.invoke()\n",
        "    predictions = interpreter.get_tensor(output_index)\n",
        "    # Labels are one-hot, so compare the argmax of prediction and label.\n",
        "    if np.argmax(predictions) == np.argmax(label):\n",
        "      num_correct += 1\n",
        "\n",
        "    if total_seen % 1000 == 0:\n",
        "      print(\"Accuracy after %i images: %f\" %\n",
        "            (total_seen, float(num_correct) / float(total_seen)))\n",
        "\n",
        "  return float(num_correct) / float(total_seen)\n",
        "\n",
        "print(eval_model(interpreter, x_test, y_test))"
      ]
    },
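    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "`eval_model` counts an image as correct when the index of the largest output score matches the index of the 1 in its one-hot label. If you collect the interpreter outputs into a single array (for example, by stacking each `interpreter.get_tensor(output_index)` result with `np.concatenate`), you can compute the same accuracy in one vectorized step. The cell below is a minimal sketch. `accuracy_from_predictions` is a hypothetical helper, and the scores are made-up numbers rather than real model outputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def accuracy_from_predictions(predictions, one_hot_labels):\n",
        "  \"\"\"Fraction of rows whose argmax prediction matches the argmax label.\n",
        "\n",
        "  Hypothetical helper; computes the same quantity as the loop in eval_model.\n",
        "  \"\"\"\n",
        "  predicted = np.argmax(predictions, axis=1)\n",
        "  expected = np.argmax(one_hot_labels, axis=1)\n",
        "  return float(np.mean(predicted == expected))\n",
        "\n",
        "# Made-up scores for 4 images over 3 classes (not real model outputs).\n",
        "scores = np.array([[0.1, 0.8, 0.1],\n",
        "                   [0.7, 0.2, 0.1],\n",
        "                   [0.2, 0.2, 0.6],\n",
        "                   [0.3, 0.4, 0.3]])\n",
        "labels = np.eye(3)[[1, 0, 2, 0]]  # one-hot rows for classes 1, 0, 2, 0\n",
        "acc = accuracy_from_predictions(scores, labels)\n",
        "print(acc)  # 3 of 4 correct -> 0.75"
      ]
    },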
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Zalu5Ng7D7xi"
      },
      "source": [
        "### Apply post-training quantization\n",
        "\n",
        "You can combine pruning with other optimization techniques, such as post-training quantization. As a recap, post-training quantization converts weights to 8-bit precision while converting the Keras model to a TFLite FlatBuffer. Because each weight shrinks from 32 bits to 8 bits, the model becomes roughly 4x smaller.\n",
        "\n",
        "In the following example, we take the pruned Keras model, convert it with post-training quantization, check the size reduction, and validate its accuracy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "wbTNqf1KER0z"
      },
      "outputs": [],
      "source": [
        "converter = tf.lite.TFLiteConverter.from_keras_model_file(pruned_keras_file)\n",
        "\n",
        "converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]\n",
        "\n",
        "tflite_quant_model = converter.convert()\n",
        "\n",
        "tflite_quant_model_file = '/tmp/sparse_mnist_quant.tflite'\n",
        "with open(tflite_quant_model_file, 'wb') as f:\n",
        "  f.write(tflite_quant_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "s6a8bH_-EXeR"
      },
      "outputs": [],
      "source": [
        "_, zip_tflite = tempfile.mkstemp('.zip')\n",
        "with zipfile.ZipFile(zip_tflite, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n",
        "  f.write(tflite_quant_model_file)\n",
        "print(\"Size of the quantized tflite model before compression: %.2f MB\"\n",
        "      % (os.path.getsize(tflite_quant_model_file) / float(2**20)))\n",
        "print(\"Size of the quantized tflite model after compression: %.2f MB\"\n",
        "      % (os.path.getsize(zip_tflite) / float(2**20)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lMuuxZs_QxMt"
      },
      "source": [
        "The quantized model is roughly 1/4 the size of the original, unquantized TFLite model."
      ]
    },
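    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The factor of roughly 4 comes from storage width: a float32 weight takes 4 bytes, while an 8-bit integer takes 1. The cell below is a minimal sketch of simple symmetric, per-tensor int8 quantization on synthetic weights. It illustrates the arithmetic only and is not the exact scheme the TFLite converter applies. The weight values and the scale rule (`max(|w|) / 127`) are assumptions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "# Synthetic float32 weights standing in for a layer's kernel.\n",
        "weights = rng.normal(0.0, 0.05, size=(256, 256)).astype(np.float32)\n",
        "\n",
        "# Symmetric per-tensor quantization: map [-max|w|, max|w|] onto [-127, 127].\n",
        "scale = np.float32(np.max(np.abs(weights)) / 127.0)\n",
        "quantized = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)\n",
        "\n",
        "# Dequantize to see how much precision was lost.\n",
        "restored = quantized.astype(np.float32) * scale\n",
        "max_error = float(np.max(np.abs(weights - restored)))\n",
        "\n",
        "print(\"float32 bytes:\", weights.nbytes)    # 262144\n",
        "print(\"int8 bytes:   \", quantized.nbytes)  # 65536, i.e. 1/4\n",
        "print(\"scale: %.6f, max abs error: %.6f\" % (scale, max_error))"
      ]
    },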
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Dkv9mauCEami"
      },
      "outputs": [],
      "source": [
        "interpreter = tf.lite.Interpreter(model_path=str(tflite_quant_model_file))\n",
        "interpreter.allocate_tensors()\n",
        "input_index = interpreter.get_input_details()[0][\"index\"]\n",
        "output_index = interpreter.get_output_details()[0][\"index\"]\n",
        "\n",
        "print(eval_model(interpreter, x_test, y_test))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ygzK3KZcoZ-w"
      },
      "source": [
        "## Conclusion\n",
        "\n",
        "In this tutorial, we showed you how to create *sparse models* with the weight pruning API of the TensorFlow Model Optimization Toolkit. Right now, this lets you create models that take significantly less space on disk. Sparse models can also be executed more efficiently by skipping computation on zeroed weights; in the future, TensorFlow Lite will provide such capabilities.\n",
        "\n",
        "More specifically, we walked you through an end-to-end example of training a simple MNIST model that uses the weight pruning API. We showed you how to convert it to the TensorFlow Lite format for mobile deployment, and demonstrated that simple file compression reduced the model size 5x.\n",
        "\n",
        "We encourage you to try this new capability on your Keras models, which can be particularly important for deployment in resource-constrained environments.\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "JCkPHee-alKM"
      ],
      "name": "pruning_with_keras.ipynb",
      "provenance": [],
      "toc_visible": true,
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
